-- |
-- AST traversal helpers
--
module Language.PureScript.AST.Traversals where

import Prelude
import Protolude (swap)

import Control.Monad ((<=<), (>=>))
import Control.Monad.Trans.State (StateT(..))

import Data.Foldable (fold)
import Data.Functor.Identity (runIdentity)
import Data.List (mapAccumL)
import Data.Maybe (mapMaybe)
import Data.List.NonEmpty qualified as NEL
import Data.Map qualified as M
import Data.Set qualified as S

import Language.PureScript.AST.Binders (Binder(..), binderNames)
import Language.PureScript.AST.Declarations (CaseAlternative(..), DataConstructorDeclaration(..), Declaration(..), DoNotationElement(..), Expr(..), Guard(..), GuardedExpr(..), TypeDeclarationData(..), TypeInstanceBody(..), pattern ValueDecl, ValueDeclarationData(..), mapTypeInstanceBody, traverseTypeInstanceBody)
import Language.PureScript.AST.Literals (Literal(..))
import Language.PureScript.Names (pattern ByNullSourcePos, Ident)
import Language.PureScript.Traversals (sndM, sndM', thirdM)
import Language.PureScript.TypeClassDictionaries (TypeClassDictionaryInScope(..))
import Language.PureScript.Types (Constraint(..), SourceType, mapConstraintArgs)

-- | Effectfully rebuild a 'GuardedExpr'. Each guard is run through the first
-- function (in order), then the right-hand side through the second, and the
-- results are reassembled into a new 'GuardedExpr'.
guardedExprM :: Applicative m
             => (Guard -> m Guard)
             -> (Expr -> m Expr)
             -> GuardedExpr
             -> m GuardedExpr
guardedExprM onGuard onRhs (GuardedExpr grds body) =
  let newGuards = traverse onGuard grds
      newBody = onRhs body
  in GuardedExpr <$> newGuards <*> newBody

-- | Pure counterpart of 'guardedExprM'. It applies the first function to each
-- guard and the second to the right-hand side.
mapGuardedExpr :: (Guard -> Guard)
               -> (Expr -> Expr)
               -> GuardedExpr
               -> GuardedExpr
mapGuardedExpr onGuard onRhs (GuardedExpr grds body) =
  let newGuards = map onGuard grds
  in GuardedExpr newGuards (onRhs body)

-- | Effectfully map over the immediate children of a 'Literal'. Array
-- elements and object field values are visited in order. Every other literal
-- has no children and is returned unchanged.
litM :: Monad m => (a -> m a) -> Literal a -> m (Literal a)
litM go literal = case literal of
  ArrayLiteral elems -> ArrayLiteral <$> traverse go elems
  ObjectLiteral fields -> ObjectLiteral <$> traverse (sndM go) fields
  _ -> pure literal

-- | Build pure, bottom-up rewriters for declarations, expressions and
-- binders.
--
-- Each returned function first rebuilds the children of a node. It then
-- applies the matching user function (@f@ for declarations, @g@ for
-- expressions, @h@ for binders) to the rebuilt node. Nodes whose children are
-- not handled below are passed to the user function unchanged.
--
-- Guards, case alternatives and do-notation elements have no user function of
-- their own. Only the declarations, expressions and binders inside them are
-- rewritten.
everywhereOnValues
  :: (Declaration -> Declaration)
  -> (Expr -> Expr)
  -> (Binder -> Binder)
  -> ( Declaration -> Declaration
     , Expr -> Expr
     , Binder -> Binder
     )
everywhereOnValues f g h = (f', g', h')
  where
  -- Declarations: rewrite nested declarations, binders, guarded right-hand
  -- sides and bound expressions first, then apply @f@.
  f' :: Declaration -> Declaration
  f' (DataBindingGroupDeclaration ds) = f (DataBindingGroupDeclaration (fmap f' ds))
  f' (ValueDecl sa name nameKind bs val) =
     f (ValueDecl sa name nameKind (fmap h' bs) (fmap (mapGuardedExpr handleGuard g') val))
  f' (BoundValueDeclaration sa b expr) = f (BoundValueDeclaration sa (h' b) (g' expr))
  f' (BindingGroupDeclaration ds) = f (BindingGroupDeclaration (fmap (\(name, nameKind, val) -> (name, nameKind, g' val)) ds))
  f' (TypeClassDeclaration sa name args implies deps ds) = f (TypeClassDeclaration sa name args implies deps (fmap f' ds))
  f' (TypeInstanceDeclaration sa na ch idx name cs className args ds) = f (TypeInstanceDeclaration sa na ch idx name cs className args (mapTypeInstanceBody (fmap f') ds))
  f' other = f other

  -- Expressions: rewrite sub-expressions, binders, case alternatives, local
  -- declarations and do/ado elements first, then apply @g@. Type annotations
  -- (e.g. in 'VisibleTypeApp' and 'TypedValue') are left untouched.
  g' :: Expr -> Expr
  g' (Literal ss l) = g (Literal ss (lit g' l))
  g' (UnaryMinus ss v) = g (UnaryMinus ss (g' v))
  g' (BinaryNoParens op v1 v2) = g (BinaryNoParens (g' op) (g' v1) (g' v2))
  g' (Parens v) = g (Parens (g' v))
  g' (Accessor prop v) = g (Accessor prop (g' v))
  g' (ObjectUpdate obj vs) = g (ObjectUpdate (g' obj) (fmap (fmap g') vs))
  g' (ObjectUpdateNested obj vs) = g (ObjectUpdateNested (g' obj) (fmap g' vs))
  g' (Abs binder v) = g (Abs (h' binder) (g' v))
  g' (App v1 v2) = g (App (g' v1) (g' v2))
  g' (VisibleTypeApp v ty) = g (VisibleTypeApp (g' v) ty)
  g' (Unused v) = g (Unused (g' v))
  g' (IfThenElse v1 v2 v3) = g (IfThenElse (g' v1) (g' v2) (g' v3))
  g' (Case vs alts) = g (Case (fmap g' vs) (fmap handleCaseAlternative alts))
  g' (TypedValue check v ty) = g (TypedValue check (g' v) ty)
  g' (Let w ds v) = g (Let w (fmap f' ds) (g' v))
  g' (Do m es) = g (Do m (fmap handleDoNotationElement es))
  g' (Ado m es v) = g (Ado m (fmap handleDoNotationElement es) (g' v))
  g' (PositionedValue pos com v) = g (PositionedValue pos com (g' v))
  g' other = g other

  -- Binders: rewrite nested binders first, then apply @h@.
  h' :: Binder -> Binder
  h' (ConstructorBinder ss ctor bs) = h (ConstructorBinder ss ctor (fmap h' bs))
  h' (BinaryNoParensBinder b1 b2 b3) = h (BinaryNoParensBinder (h' b1) (h' b2) (h' b3))
  h' (ParensInBinder b) = h (ParensInBinder (h' b))
  h' (LiteralBinder ss l) = h (LiteralBinder ss (lit h' l))
  h' (NamedBinder ss name b) = h (NamedBinder ss name (h' b))
  h' (PositionedBinder pos com b) = h (PositionedBinder pos com (h' b))
  h' (TypedBinder t b) = h (TypedBinder t (h' b))
  h' other = h other

  -- Apply a function to array elements and object field values; other
  -- literals have no children.
  lit :: (a -> a) -> Literal a -> Literal a
  lit go (ArrayLiteral as) = ArrayLiteral (fmap go as)
  lit go (ObjectLiteral as) = ObjectLiteral (fmap (fmap go) as)
  lit _ other = other

  -- Case alternatives are not passed to a user function; only their binders
  -- and guarded right-hand sides are rewritten.
  handleCaseAlternative :: CaseAlternative -> CaseAlternative
  handleCaseAlternative ca =
    ca { caseAlternativeBinders = fmap h' (caseAlternativeBinders ca)
       , caseAlternativeResult = fmap (mapGuardedExpr handleGuard g') (caseAlternativeResult ca)
       }

  -- Do-notation elements are likewise rebuilt without a user function.
  handleDoNotationElement :: DoNotationElement -> DoNotationElement
  handleDoNotationElement (DoNotationValue v) = DoNotationValue (g' v)
  handleDoNotationElement (DoNotationBind b v) = DoNotationBind (h' b) (g' v)
  handleDoNotationElement (DoNotationLet ds) = DoNotationLet (fmap f' ds)
  handleDoNotationElement (PositionedDoNotationElement pos com e) = PositionedDoNotationElement pos com (handleDoNotationElement e)

  -- Guards: rewrite the condition expression or the pattern binder and its
  -- scrutinee.
  handleGuard :: Guard -> Guard
  handleGuard (ConditionGuard e) = ConditionGuard (g' e)
  handleGuard (PatternGuard b e) = PatternGuard (h' b) (g' e)

-- | Build monadic, top-down rewriters for declarations, expressions and
-- binders.
--
-- Each returned function first passes a node to its user function (@f@ for
-- declarations, @g@ for expressions, @h@ for binders). It then descends into
-- the children of the node that function returned, treating every child the
-- same way. As a result, each visited node is handed to its user function
-- exactly once.
--
-- Guards, case alternatives and do-notation elements have no user function of
-- their own. Only the declarations, expressions and binders inside them are
-- rewritten.
everywhereOnValuesTopDownM
  :: forall m
   . (Monad m)
  => (Declaration -> m Declaration)
  -> (Expr -> m Expr)
  -> (Binder -> m Binder)
  -> ( Declaration -> m Declaration
     , Expr -> m Expr
     , Binder -> m Binder
     )
everywhereOnValuesTopDownM f g h = (f' <=< f, g' <=< g, h' <=< h)
  where

  -- The primed helpers receive a node that has ALREADY been passed to the
  -- matching user function, and only rewrite its children. A node without
  -- children is therefore returned as-is. Handing it to the user function
  -- again would apply that function twice to every leaf.

  f' :: Declaration -> m Declaration
  f' (DataBindingGroupDeclaration ds) = DataBindingGroupDeclaration <$> traverse (f' <=< f) ds
  f' (ValueDecl sa name nameKind bs val) =
     ValueDecl sa name nameKind <$> traverse (h' <=< h) bs <*> traverse (guardedExprM handleGuard (g' <=< g)) val
  f' (BindingGroupDeclaration ds) = BindingGroupDeclaration <$> traverse (\(name, nameKind, val) -> (name, nameKind, ) <$> (g val >>= g')) ds
  f' (TypeClassDeclaration sa name args implies deps ds) = TypeClassDeclaration sa name args implies deps <$> traverse (f' <=< f) ds
  f' (TypeInstanceDeclaration sa na ch idx name cs className args ds) = TypeInstanceDeclaration sa na ch idx name cs className args <$> traverseTypeInstanceBody (traverse (f' <=< f)) ds
  f' (BoundValueDeclaration sa b expr) = BoundValueDeclaration sa <$> (h' <=< h) b <*> (g' <=< g) expr
  f' other = pure other

  g' :: Expr -> m Expr
  g' (Literal ss l) = Literal ss <$> litM (g >=> g') l
  g' (UnaryMinus ss v) = UnaryMinus ss <$> (g v >>= g')
  g' (BinaryNoParens op v1 v2) = BinaryNoParens <$> (g op >>= g') <*> (g v1 >>= g') <*> (g v2 >>= g')
  g' (Parens v) = Parens <$> (g v >>= g')
  g' (Accessor prop v) = Accessor prop <$> (g v >>= g')
  g' (ObjectUpdate obj vs) = ObjectUpdate <$> (g obj >>= g') <*> traverse (sndM (g' <=< g)) vs
  g' (ObjectUpdateNested obj vs) = ObjectUpdateNested <$> (g obj >>= g') <*> traverse (g' <=< g) vs
  g' (Abs binder v) = Abs <$> (h binder >>= h') <*> (g v >>= g')
  g' (App v1 v2) = App <$> (g v1 >>= g') <*> (g v2 >>= g')
  g' (VisibleTypeApp v ty) = VisibleTypeApp <$> (g v >>= g') <*> pure ty
  g' (Unused v) = Unused <$> (g v >>= g')
  g' (IfThenElse v1 v2 v3) = IfThenElse <$> (g v1 >>= g') <*> (g v2 >>= g') <*> (g v3 >>= g')
  g' (Case vs alts) = Case <$> traverse (g' <=< g) vs <*> traverse handleCaseAlternative alts
  g' (TypedValue check v ty) = TypedValue check <$> (g v >>= g') <*> pure ty
  g' (Let w ds v) = Let w <$> traverse (f' <=< f) ds <*> (g v >>= g')
  g' (Do m es) = Do m <$> traverse handleDoNotationElement es
  g' (Ado m es v) = Ado m <$> traverse handleDoNotationElement es <*> (g v >>= g')
  g' (PositionedValue pos com v) = PositionedValue pos com <$> (g v >>= g')
  g' other = pure other

  h' :: Binder -> m Binder
  h' (LiteralBinder ss l) = LiteralBinder ss <$> litM (h >=> h') l
  h' (ConstructorBinder ss ctor bs) = ConstructorBinder ss ctor <$> traverse (h' <=< h) bs
  h' (BinaryNoParensBinder b1 b2 b3) = BinaryNoParensBinder <$> (h b1 >>= h') <*> (h b2 >>= h') <*> (h b3 >>= h')
  h' (ParensInBinder b) = ParensInBinder <$> (h b >>= h')
  h' (NamedBinder ss name b) = NamedBinder ss name <$> (h b >>= h')
  h' (PositionedBinder pos com b) = PositionedBinder pos com <$> (h b >>= h')
  h' (TypedBinder t b) = TypedBinder t <$> (h b >>= h')
  h' other = pure other

  -- Case alternatives have no user function: their binders and guarded
  -- right-hand sides are rewritten top-down.
  handleCaseAlternative :: CaseAlternative -> m CaseAlternative
  handleCaseAlternative (CaseAlternative bs val) =
    CaseAlternative
      <$> traverse (h' <=< h) bs
      <*> traverse (guardedExprM handleGuard (g' <=< g)) val

  handleDoNotationElement :: DoNotationElement -> m DoNotationElement
  handleDoNotationElement (DoNotationValue v) = DoNotationValue <$> (g' <=< g) v
  handleDoNotationElement (DoNotationBind b v) = DoNotationBind <$> (h' <=< h) b <*> (g' <=< g) v
  handleDoNotationElement (DoNotationLet ds) = DoNotationLet <$> traverse (f' <=< f) ds
  handleDoNotationElement (PositionedDoNotationElement pos com e) = PositionedDoNotationElement pos com <$> handleDoNotationElement e

  handleGuard :: Guard -> m Guard
  handleGuard (ConditionGuard e) = ConditionGuard <$> (g' <=< g) e
  handleGuard (PatternGuard b e) = PatternGuard <$> (h' <=< h) b <*> (g' <=< g) e

-- | Build monadic, bottom-up rewriters for declarations, expressions and
-- binders.
--
-- Each returned function first rebuilds the children of a node, running their
-- effects in the order the children are stored in the node. It then passes
-- the rebuilt node to the matching user function (@f@ for declarations, @g@
-- for expressions, @h@ for binders). Nodes whose children are not handled
-- below are passed straight to the user function.
--
-- Guards, case alternatives and do-notation elements have no user function of
-- their own. Only the declarations, expressions and binders inside them are
-- rewritten.
everywhereOnValuesM
  :: forall m
   . (Monad m)
  => (Declaration -> m Declaration)
  -> (Expr -> m Expr)
  -> (Binder -> m Binder)
  -> ( Declaration -> m Declaration
     , Expr -> m Expr
     , Binder -> m Binder
     )
everywhereOnValuesM f g h = (f', g', h')
  where

  -- Declarations: children first, then @f@ on the rebuilt declaration.
  f' :: Declaration -> m Declaration
  f' (DataBindingGroupDeclaration ds) = (DataBindingGroupDeclaration <$> traverse f' ds) >>= f
  f' (ValueDecl sa name nameKind bs val) =
    ValueDecl sa name nameKind <$> traverse h' bs <*> traverse (guardedExprM handleGuard g') val >>= f
  f' (BindingGroupDeclaration ds) = (BindingGroupDeclaration <$> traverse (\(name, nameKind, val) -> (name, nameKind, ) <$> g' val) ds) >>= f
  f' (BoundValueDeclaration sa b expr) = (BoundValueDeclaration sa <$> h' b <*> g' expr) >>= f
  f' (TypeClassDeclaration sa name args implies deps ds) = (TypeClassDeclaration sa name args implies deps <$> traverse f' ds) >>= f
  f' (TypeInstanceDeclaration sa na ch idx name cs className args ds) = (TypeInstanceDeclaration sa na ch idx name cs className args <$> traverseTypeInstanceBody (traverse f') ds) >>= f
  f' other = f other

  -- Expressions: children first, then @g@ on the rebuilt expression.
  g' :: Expr -> m Expr
  g' (Literal ss l) = (Literal ss <$> litM g' l) >>= g
  g' (UnaryMinus ss v) = (UnaryMinus ss <$> g' v) >>= g
  g' (BinaryNoParens op v1 v2) = (BinaryNoParens <$> g' op <*> g' v1 <*> g' v2) >>= g
  g' (Parens v) = (Parens <$> g' v) >>= g
  g' (Accessor prop v) = (Accessor prop <$> g' v) >>= g
  g' (ObjectUpdate obj vs) = (ObjectUpdate <$> g' obj <*> traverse (sndM g') vs) >>= g
  g' (ObjectUpdateNested obj vs) = (ObjectUpdateNested <$> g' obj <*> traverse g' vs) >>= g
  g' (Abs binder v) = (Abs <$> h' binder <*> g' v) >>= g
  g' (App v1 v2) = (App <$> g' v1 <*> g' v2) >>= g
  g' (VisibleTypeApp v ty) = (VisibleTypeApp <$> g' v <*> pure ty) >>= g
  g' (Unused v) = (Unused <$> g' v) >>= g
  g' (IfThenElse v1 v2 v3) = (IfThenElse <$> g' v1 <*> g' v2 <*> g' v3) >>= g
  g' (Case vs alts) = (Case <$> traverse g' vs <*> traverse handleCaseAlternative alts) >>= g
  g' (TypedValue check v ty) = (TypedValue check <$> g' v <*> pure ty) >>= g
  g' (Let w ds v) = (Let w <$> traverse f' ds <*> g' v) >>= g
  g' (Do m es) = (Do m <$> traverse handleDoNotationElement es) >>= g
  g' (Ado m es v) = (Ado m <$> traverse handleDoNotationElement es <*> g' v) >>= g
  g' (PositionedValue pos com v) = (PositionedValue pos com <$> g' v) >>= g
  g' other = g other

  -- Binders: children first, then @h@ on the rebuilt binder.
  h' :: Binder -> m Binder
  h' (LiteralBinder ss l) = (LiteralBinder ss <$> litM h' l) >>= h
  h' (ConstructorBinder ss ctor bs) = (ConstructorBinder ss ctor <$> traverse h' bs) >>= h
  h' (BinaryNoParensBinder b1 b2 b3) = (BinaryNoParensBinder <$> h' b1 <*> h' b2 <*> h' b3) >>= h
  h' (ParensInBinder b) = (ParensInBinder <$> h' b) >>= h
  h' (NamedBinder ss name b) = (NamedBinder ss name <$> h' b) >>= h
  h' (PositionedBinder pos com b) = (PositionedBinder pos com <$> h' b) >>= h
  h' (TypedBinder t b) = (TypedBinder t <$> h' b) >>= h
  h' other = h other

  -- Case alternatives are rebuilt without a user function.
  handleCaseAlternative :: CaseAlternative -> m CaseAlternative
  handleCaseAlternative (CaseAlternative bs val) =
    CaseAlternative
      <$> traverse h' bs
      <*> traverse (guardedExprM handleGuard g') val

  -- Do-notation elements are rebuilt without a user function.
  handleDoNotationElement :: DoNotationElement -> m DoNotationElement
  handleDoNotationElement (DoNotationValue v) = DoNotationValue <$> g' v
  handleDoNotationElement (DoNotationBind b v) = DoNotationBind <$> h' b <*> g' v
  handleDoNotationElement (DoNotationLet ds) = DoNotationLet <$> traverse f' ds
  handleDoNotationElement (PositionedDoNotationElement pos com e) = PositionedDoNotationElement pos com <$> handleDoNotationElement e

  -- Guards: rewrite the condition, or the pattern binder and then its
  -- scrutinee.
  handleGuard :: Guard -> m Guard
  handleGuard (ConditionGuard e) = ConditionGuard <$> g' e
  handleGuard (PatternGuard b e) = PatternGuard <$> h' b <*> g' e

-- | Build queries that fold over every declaration, expression, binder, case
-- alternative and do-notation element reachable from a root.
--
-- Each visited node contributes the result of its user function (@f@, @g@,
-- @h@, @i@ or @j@). That result is combined, on the left, with the results of
-- the node's children using the supplied operator. Lists of children are
-- folded with 'foldl', seeded with the node's own result.
--
-- Guards have no user function; they contribute only the results of the
-- expressions and binders inside them. Instance declarations are only
-- descended into when they have an explicit body.
everythingOnValues
  :: forall r
   . (r -> r -> r)
  -> (Declaration -> r)
  -> (Expr -> r)
  -> (Binder -> r)
  -> (CaseAlternative -> r)
  -> (DoNotationElement -> r)
  -> ( Declaration -> r
     , Expr -> r
     , Binder -> r
     , CaseAlternative -> r
     , DoNotationElement -> r
     )
everythingOnValues (<>.) f g h i j = (f', g', h', i', j')
  where

  -- Declarations: @f@ on the node, followed by nested declarations, binders,
  -- guards and right-hand sides.
  f' :: Declaration -> r
  f' d@(DataBindingGroupDeclaration ds) = foldl (<>.) (f d) (fmap f' ds)
  f' d@(ValueDeclaration vd) = foldl (<>.) (f d) (fmap h' (valdeclBinders vd) ++ concatMap (\(GuardedExpr grd v) -> fmap k' grd ++ [g' v]) (valdeclExpression vd))
  f' d@(BindingGroupDeclaration ds) = foldl (<>.) (f d) (fmap (\(_, _, val) -> g' val) ds)
  f' d@(TypeClassDeclaration _ _ _ _ _ ds) = foldl (<>.) (f d) (fmap f' ds)
  f' d@(TypeInstanceDeclaration _ _ _ _ _ _ _ _ (ExplicitInstance ds)) = foldl (<>.) (f d) (fmap f' ds)
  f' d@(BoundValueDeclaration _ b expr) = f d <>. h' b <>. g' expr
  f' d = f d

  -- Expressions: @g@ on the node, followed by its children in source order.
  g' :: Expr -> r
  g' v@(Literal _ l) = lit (g v) g' l
  g' v@(UnaryMinus _ v1) = g v <>. g' v1
  g' v@(BinaryNoParens op v1 v2) = g v <>. g' op <>. g' v1 <>. g' v2
  g' v@(Parens v1) = g v <>. g' v1
  g' v@(Accessor _ v1) = g v <>. g' v1
  g' v@(ObjectUpdate obj vs) = foldl (<>.) (g v <>. g' obj) (fmap (g' . snd) vs)
  g' v@(ObjectUpdateNested obj vs) = foldl (<>.) (g v <>. g' obj) (fmap g' vs)
  g' v@(Abs b v1) = g v <>. h' b <>. g' v1
  g' v@(App v1 v2) = g v <>. g' v1 <>. g' v2
  g' v@(VisibleTypeApp v' _) = g v <>. g' v'
  g' v@(Unused v1) = g v <>. g' v1
  g' v@(IfThenElse v1 v2 v3) = g v <>. g' v1 <>. g' v2 <>. g' v3
  g' v@(Case vs alts) = foldl (<>.) (foldl (<>.) (g v) (fmap g' vs)) (fmap i' alts)
  g' v@(TypedValue _ v1 _) = g v <>. g' v1
  g' v@(Let _ ds v1) = foldl (<>.) (g v) (fmap f' ds) <>. g' v1
  g' v@(Do _ es) = foldl (<>.) (g v) (fmap j' es)
  g' v@(Ado _ es v1) = foldl (<>.) (g v) (fmap j' es) <>. g' v1
  g' v@(PositionedValue _ _ v1) = g v <>. g' v1
  g' v = g v

  -- Binders: @h@ on the node, followed by nested binders.
  h' :: Binder -> r
  h' b@(LiteralBinder _ l) = lit (h b) h' l
  h' b@(ConstructorBinder _ _ bs) = foldl (<>.) (h b) (fmap h' bs)
  h' b@(BinaryNoParensBinder b1 b2 b3) = h b <>. h' b1 <>. h' b2 <>. h' b3
  h' b@(ParensInBinder b1) = h b <>. h' b1
  h' b@(NamedBinder _ _ b1) = h b <>. h' b1
  h' b@(PositionedBinder _ _ b1) = h b <>. h' b1
  h' b@(TypedBinder _ b1) = h b <>. h' b1
  h' b = h b

  -- Fold the given seed with the results for array elements or object field
  -- values; other literals contribute only the seed.
  lit :: r -> (a -> r) -> Literal a -> r
  lit r go (ArrayLiteral as) = foldl (<>.) r (fmap go as)
  lit r go (ObjectLiteral as) = foldl (<>.) r (fmap (go . snd) as)
  lit r _ _ = r

  -- Case alternatives: @i@ on the node, then binders, then each guarded
  -- right-hand side (its guards followed by its expression).
  i' :: CaseAlternative -> r
  i' ca@(CaseAlternative bs gs) =
    foldl (<>.) (i ca) (fmap h' bs ++ concatMap (\(GuardedExpr grd val) -> fmap k' grd ++ [g' val]) gs)

  -- Do-notation elements: @j@ on the node, followed by its contents.
  j' :: DoNotationElement -> r
  j' e@(DoNotationValue v) = j e <>. g' v
  j' e@(DoNotationBind b v) = j e <>. h' b <>. g' v
  j' e@(DoNotationLet ds) = foldl (<>.) (j e) (fmap f' ds)
  j' e@(PositionedDoNotationElement _ _ e1) = j e <>. j' e1

  -- Guards have no user function of their own.
  k' :: Guard -> r
  k' (ConditionGuard e) = g' e
  k' (PatternGuard b e) = h' b <>. g' e

-- | Build context-sensitive queries over declarations, expressions, binders,
-- case alternatives and do-notation elements.
--
-- Every visited node is passed, together with the current context, to its
-- user function. That function returns an updated context and a result for
-- the node. The node's children are then queried with the /updated/ context,
-- and the node's result is combined (on the left) with the children's results
-- using the supplied operator. @r0@ is the result of nodes without handled
-- children and the seed of the left folds over lists of children.
--
-- Guards have no user function of their own. Their expressions and binders
-- are queried in the context of the enclosing declaration or alternative.
everythingWithContextOnValues
  :: forall s r
   . s
  -> r
  -> (r -> r -> r)
  -> (s -> Declaration       -> (s, r))
  -> (s -> Expr              -> (s, r))
  -> (s -> Binder            -> (s, r))
  -> (s -> CaseAlternative   -> (s, r))
  -> (s -> DoNotationElement -> (s, r))
  -> ( Declaration       -> r
     , Expr              -> r
     , Binder            -> r
     , CaseAlternative   -> r
     , DoNotationElement -> r)
everythingWithContextOnValues s0 r0 (<>.) f g h i j = (f'' s0, g'' s0, h'' s0, i'' s0, j'' s0)
  where

  -- Apply the user function to a declaration, then query its children.
  f'' :: s -> Declaration -> r
  f'' s d = let (s', r) = f s d in r <>. f' s' d

  f' :: s -> Declaration -> r
  f' s (DataBindingGroupDeclaration ds) = foldl (<>.) r0 (fmap (f'' s) ds)
  f' s (ValueDeclaration vd) = foldl (<>.) r0 (fmap (h'' s) (valdeclBinders vd) ++ concatMap (\(GuardedExpr grd v) -> fmap (k' s) grd ++ [g'' s v]) (valdeclExpression vd))
  f' s (BindingGroupDeclaration ds) = foldl (<>.) r0 (fmap (\(_, _, val) -> g'' s val) ds)
  f' s (TypeClassDeclaration _ _ _ _ _ ds) = foldl (<>.) r0 (fmap (f'' s) ds)
  f' s (TypeInstanceDeclaration _ _ _ _ _ _ _ _ (ExplicitInstance ds)) = foldl (<>.) r0 (fmap (f'' s) ds)
  -- A pattern-bound declaration (e.g. inside @let@) has a binder and an
  -- expression that must be visited like any other.
  f' s (BoundValueDeclaration _ b expr) = h'' s b <>. g'' s expr
  f' _ _ = r0

  -- Apply the user function to an expression, then query its children.
  g'' :: s -> Expr -> r
  g'' s v = let (s', r) = g s v in r <>. g' s' v

  g' :: s -> Expr -> r
  g' s (Literal _ l) = lit g'' s l
  g' s (UnaryMinus _ v1) = g'' s v1
  g' s (BinaryNoParens op v1 v2) = g'' s op <>. g'' s v1 <>. g'' s v2
  g' s (Parens v1) = g'' s v1
  g' s (Accessor _ v1) = g'' s v1
  g' s (ObjectUpdate obj vs) = foldl (<>.) (g'' s obj) (fmap (g'' s . snd) vs)
  g' s (ObjectUpdateNested obj vs) = foldl (<>.) (g'' s obj) (fmap (g'' s) vs)
  g' s (Abs binder v1) = h'' s binder <>. g'' s v1
  g' s (App v1 v2) = g'' s v1 <>. g'' s v2
  g' s (VisibleTypeApp v _) = g'' s v
  g' s (Unused v) = g'' s v
  g' s (IfThenElse v1 v2 v3) = g'' s v1 <>. g'' s v2 <>. g'' s v3
  g' s (Case vs alts) = foldl (<>.) (foldl (<>.) r0 (fmap (g'' s) vs)) (fmap (i'' s) alts)
  g' s (TypedValue _ v1 _) = g'' s v1
  g' s (Let _ ds v1) = foldl (<>.) r0 (fmap (f'' s) ds) <>. g'' s v1
  g' s (Do _ es) = foldl (<>.) r0 (fmap (j'' s) es)
  g' s (Ado _ es v1) = foldl (<>.) r0 (fmap (j'' s) es) <>. g'' s v1
  g' s (PositionedValue _ _ v1) = g'' s v1
  g' _ _ = r0

  -- Apply the user function to a binder, then query nested binders.
  h'' :: s -> Binder -> r
  h'' s b = let (s', r) = h s b in r <>. h' s' b

  h' :: s -> Binder -> r
  h' s (LiteralBinder _ l) = lit h'' s l
  h' s (ConstructorBinder _ _ bs) = foldl (<>.) r0 (fmap (h'' s) bs)
  h' s (BinaryNoParensBinder b1 b2 b3) = h'' s b1 <>. h'' s b2 <>. h'' s b3
  h' s (ParensInBinder b) = h'' s b
  h' s (NamedBinder _ _ b1) = h'' s b1
  h' s (PositionedBinder _ _ b1) = h'' s b1
  h' s (TypedBinder _ b1) = h'' s b1
  h' _ _ = r0

  -- Query array elements / object field values; other literals yield r0.
  lit :: (s -> a -> r) -> s -> Literal a -> r
  lit go s (ArrayLiteral as) = foldl (<>.) r0 (fmap (go s) as)
  lit go s (ObjectLiteral as) = foldl (<>.) r0 (fmap (go s . snd) as)
  lit _ _ _ = r0

  i'' :: s -> CaseAlternative -> r
  i'' s ca = let (s', r) = i s ca in r <>. i' s' ca

  i' :: s -> CaseAlternative -> r
  i' s (CaseAlternative bs gs) = foldl (<>.) r0 (fmap (h'' s) bs ++ concatMap (\(GuardedExpr grd val) -> fmap (k' s) grd ++ [g'' s val]) gs)

  j'' :: s -> DoNotationElement -> r
  j'' s e = let (s', r) = j s e in r <>. j' s' e

  j' :: s -> DoNotationElement -> r
  j' s (DoNotationValue v) = g'' s v
  j' s (DoNotationBind b v) = h'' s b <>. g'' s v
  j' s (DoNotationLet ds) = foldl (<>.) r0 (fmap (f'' s) ds)
  j' s (PositionedDoNotationElement _ _ e1) = j'' s e1

  -- Guards have no user function; query their contents directly.
  k' :: s -> Guard -> r
  k' s (ConditionGuard e) = g'' s e
  k' s (PatternGuard b e) = h'' s b <>. g'' s e

-- | Pure counterpart of 'everywhereWithContextOnValuesM'. Each step returns
-- an updated context together with the (possibly rewritten) node, and the
-- monadic traversal is run in @Identity@.
everywhereWithContextOnValues
  :: forall s
   . s
  -> (s -> Declaration       -> (s, Declaration))
  -> (s -> Expr              -> (s, Expr))
  -> (s -> Binder            -> (s, Binder))
  -> (s -> CaseAlternative   -> (s, CaseAlternative))
  -> (s -> DoNotationElement -> (s, DoNotationElement))
  -> (s -> Guard             -> (s, Guard))
  -> ( Declaration       -> Declaration
     , Expr              -> Expr
     , Binder            -> Binder
     , CaseAlternative   -> CaseAlternative
     , DoNotationElement -> DoNotationElement
     , Guard             -> Guard
     )
everywhereWithContextOnValues s f g h i j k =
  (unwrap onDecl, unwrap onExpr, unwrap onBinder, unwrap onAlt, unwrap onDoElem, unwrap onGuard)
  where
  (onDecl, onExpr, onBinder, onAlt, onDoElem, onGuard) =
    everywhereWithContextOnValuesM s (purely f) (purely g) (purely h) (purely i) (purely j) (purely k)

  -- Lift a pure two-argument step into an applicative result.
  purely step ctx node = pure (step ctx node)

  -- Strip the Identity wrapper off a traversal's result.
  unwrap traversal = runIdentity . traversal

-- | Build context-sensitive, monadic rewriters for declarations,
-- expressions, binders, case alternatives, do-notation elements and guards.
--
-- Each node is first passed, together with the current context, to its user
-- function. That function returns an updated context and a (possibly
-- rewritten) node. The children of the rewritten node are then traversed
-- with the updated context.
--
-- Within a guarded right-hand side (of a value declaration or a case
-- alternative), the context returned by the user guard function @k@ for one
-- guard is passed on to the next guard, and finally to the expression.
everywhereWithContextOnValuesM
  :: forall m s
   . (Monad m)
  => s
  -> (s -> Declaration       -> m (s, Declaration))
  -> (s -> Expr              -> m (s, Expr))
  -> (s -> Binder            -> m (s, Binder))
  -> (s -> CaseAlternative   -> m (s, CaseAlternative))
  -> (s -> DoNotationElement -> m (s, DoNotationElement))
  -> (s -> Guard             -> m (s, Guard))
  -> ( Declaration       -> m Declaration
     , Expr              -> m Expr
     , Binder            -> m Binder
     , CaseAlternative   -> m CaseAlternative
     , DoNotationElement -> m DoNotationElement
     , Guard             -> m Guard
     )
everywhereWithContextOnValuesM s0 f g h i j k = (f'' s0, g'' s0, h'' s0, i'' s0, j'' s0, k'' s0)
  where
  -- Double-primed helpers apply the user function and then recurse with the
  -- context it returned; single-primed helpers only rebuild children.
  f'' :: s -> Declaration -> m Declaration
  f'' s = uncurry f' <=< f s

  f' :: s -> Declaration -> m Declaration
  f' s (DataBindingGroupDeclaration ds) = DataBindingGroupDeclaration <$> traverse (f'' s) ds
  f' s (ValueDecl sa name nameKind bs val) =
    -- Guards are handled exactly as in case alternatives: @k@ is applied to
    -- each guard and its context is threaded through (see guardedExprM').
    ValueDecl sa name nameKind <$> traverse (h'' s) bs <*> traverse (guardedExprM' s) val
  f' s (BindingGroupDeclaration ds) = BindingGroupDeclaration <$> traverse (thirdM (g'' s)) ds
  f' s (TypeClassDeclaration sa name args implies deps ds) = TypeClassDeclaration sa name args implies deps <$> traverse (f'' s) ds
  f' s (TypeInstanceDeclaration sa na ch idx name cs className args ds) = TypeInstanceDeclaration sa na ch idx name cs className args <$> traverseTypeInstanceBody (traverse (f'' s) ds
  f' s (BoundValueDeclaration sa b expr) = BoundValueDeclaration sa <$> h'' s b <*> g'' s expr
  f' _ other = return other

  g'' :: s -> Expr -> m Expr
  g'' s = uncurry g' <=< g s

  g' :: s -> Expr -> m Expr
  g' s (Literal ss l) = Literal ss <$> lit g'' s l
  g' s (UnaryMinus ss v) = UnaryMinus ss <$> g'' s v
  g' s (BinaryNoParens op v1 v2) = BinaryNoParens <$> g'' s op <*> g'' s v1 <*> g'' s v2
  g' s (Parens v) = Parens <$> g'' s v
  g' s (Accessor prop v) = Accessor prop <$> g'' s v
  g' s (ObjectUpdate obj vs) = ObjectUpdate <$> g'' s obj <*> traverse (sndM (g'' s)) vs
  g' s (ObjectUpdateNested obj vs) = ObjectUpdateNested <$> g'' s obj <*> traverse (g'' s) vs
  -- The lambda's own binder goes through the user binder function like every
  -- other binder.
  g' s (Abs binder v) = Abs <$> h'' s binder <*> g'' s v
  g' s (App v1 v2) = App <$> g'' s v1 <*> g'' s v2
  g' s (VisibleTypeApp v ty) = VisibleTypeApp <$> g'' s v <*> pure ty
  g' s (Unused v) = Unused <$> g'' s v
  g' s (IfThenElse v1 v2 v3) = IfThenElse <$> g'' s v1 <*> g'' s v2 <*> g'' s v3
  g' s (Case vs alts) = Case <$> traverse (g'' s) vs <*> traverse (i'' s) alts
  g' s (TypedValue check v ty) = TypedValue check <$> g'' s v <*> pure ty
  g' s (Let w ds v) = Let w <$> traverse (f'' s) ds <*> g'' s v
  g' s (Do m es) = Do m <$> traverse (j'' s) es
  g' s (Ado m es v) = Ado m <$> traverse (j'' s) es <*> g'' s v
  g' s (PositionedValue pos com v) = PositionedValue pos com <$> g'' s v
  g' _ other = return other

  h'' :: s -> Binder -> m Binder
  h'' s = uncurry h' <=< h s

  h' :: s -> Binder -> m Binder
  h' s (LiteralBinder ss l) = LiteralBinder ss <$> lit h'' s l
  h' s (ConstructorBinder ss ctor bs) = ConstructorBinder ss ctor <$> traverse (h'' s) bs
  h' s (BinaryNoParensBinder b1 b2 b3) = BinaryNoParensBinder <$> h'' s b1 <*> h'' s b2 <*> h'' s b3
  h' s (ParensInBinder b) = ParensInBinder <$> h'' s b
  h' s (NamedBinder ss name b) = NamedBinder ss name <$> h'' s b
  h' s (PositionedBinder pos com b) = PositionedBinder pos com <$> h'' s b
  h' s (TypedBinder t b) = TypedBinder t <$> h'' s b
  h' _ other = return other

  -- Rewrite array elements / object field values; other literals are kept.
  lit :: (s -> a -> m a) -> s -> Literal a -> m (Literal a)
  lit go s (ArrayLiteral as) = ArrayLiteral <$> traverse (go s) as
  lit go s (ObjectLiteral as) = ObjectLiteral <$> traverse (sndM (go s)) as
  lit _ _ other = return other

  i'' :: s -> CaseAlternative -> m CaseAlternative
  i'' s = uncurry i' <=< i s

  i' :: s -> CaseAlternative -> m CaseAlternative
  i' s (CaseAlternative bs val) = CaseAlternative <$> traverse (h'' s) bs <*> traverse (guardedExprM' s) val

  -- A specialized `guardedExprM` that keeps track of the context `s`
  -- after traversing `guards`, such that it's also exposed to `expr`.
  guardedExprM' :: s -> GuardedExpr -> m GuardedExpr
  guardedExprM' s (GuardedExpr guards expr) = do
    (guards', s') <- runStateT (traverse (StateT . goGuard) guards) s
    GuardedExpr guards' <$> g'' s' expr

  -- Like k'', but `s` is tracked.
  goGuard :: Guard -> s -> m (Guard, s)
  goGuard x s  = k s x >>= fmap swap . sndM' k'

  j'' :: s -> DoNotationElement -> m DoNotationElement
  j'' s = uncurry j' <=< j s

  j' :: s -> DoNotationElement -> m DoNotationElement
  j' s (DoNotationValue v) = DoNotationValue <$> g'' s v
  j' s (DoNotationBind b v) = DoNotationBind <$> h'' s b <*> g'' s v
  j' s (DoNotationLet ds) = DoNotationLet <$> traverse (f'' s) ds
  j' s (PositionedDoNotationElement pos com e1) = PositionedDoNotationElement pos com <$> j'' s e1

  k'' :: s -> Guard -> m Guard
  k'' s = uncurry k' <=< k s

  k' :: s -> Guard -> m Guard
  k' s (ConditionGuard e) = ConditionGuard <$> g'' s e
  k' s (PatternGuard b e) = PatternGuard <$> h'' s b <*> g'' s e

-- | An identifier tagged with the kind of binding site that introduced it.
data ScopedIdent
  = LocalIdent Ident     -- ^ introduced by a binder or a local (@let@ / do-@let@) declaration
  | ToplevelIdent Ident  -- ^ introduced as the name of a value declaration or binding group
  deriving (Show, Eq, Ord)

-- | Whether the identifier is present in the scope set, either as a local or
-- as a top-level binding.
inScope :: Ident -> S.Set ScopedIdent -> Bool
inScope ident scope = any (`S.member` scope) [LocalIdent ident, ToplevelIdent ident]

-- | Build queries that, like 'everythingOnValues', fold a 'Monoid' over
-- declarations, expressions, binders, case alternatives and do-notation
-- elements. In addition, each user function receives the set of identifiers
-- in scope at that node.
--
-- Names are brought into scope by data and value binding groups, value
-- declarations, @let@ blocks, lambda and case binders, named binders,
-- @do@/@ado@ binds and pattern guards.
--
-- * Within a @do@ block, each element sees the names bound by the elements
--   before it.
-- * Within an @ado@ block, every element sees only the enclosing scope, and
--   the final expression sees all names bound by the block.
everythingWithScope
  :: forall r
   . (Monoid r)
  => (S.Set ScopedIdent -> Declaration -> r)
  -> (S.Set ScopedIdent -> Expr -> r)
  -> (S.Set ScopedIdent -> Binder -> r)
  -> (S.Set ScopedIdent -> CaseAlternative -> r)
  -> (S.Set ScopedIdent -> DoNotationElement -> r)
  -> ( S.Set ScopedIdent -> Declaration -> r
     , S.Set ScopedIdent -> Expr -> r
     , S.Set ScopedIdent -> Binder -> r
     , S.Set ScopedIdent -> CaseAlternative -> r
     , S.Set ScopedIdent -> DoNotationElement -> r
     )
everythingWithScope f g h i j = (f'', g'', h'', i'', \s -> snd . j'' s)
  where
  -- Double-primed helpers combine the user function's result for a node with
  -- the results of its children (computed by the single-primed helpers).
  f'' :: S.Set ScopedIdent -> Declaration -> r
  f'' s a = f s a <> f' s a

  f' :: S.Set ScopedIdent -> Declaration -> r
  f' s (DataBindingGroupDeclaration ds) =
    let s' = S.union s (S.fromList (map ToplevelIdent (mapMaybe getDeclIdent (NEL.toList ds))))
    in foldMap (f'' s') ds
  f' s (ValueDecl _ name _ bs val) =
    let s' = S.insert (ToplevelIdent name) s
        s'' = S.union s' (S.fromList (concatMap localBinderNames bs))
    in foldMap (h'' s') bs <> foldMap (l' s'') val
  f' s (BindingGroupDeclaration ds) =
    let s' = S.union s (S.fromList (NEL.toList (fmap (\((_, name), _, _) -> ToplevelIdent name) ds)))
    in foldMap (\(_, _, val) -> g'' s' val) ds
  f' s (TypeClassDeclaration _ _ _ _ _ ds) = foldMap (f'' s) ds
  f' s (TypeInstanceDeclaration _ _ _ _ _ _ _ _ (ExplicitInstance ds)) = foldMap (f'' s) ds
  f' _ _ = mempty

  g'' :: S.Set ScopedIdent -> Expr -> r
  g'' s a = g s a <> g' s a

  g' :: S.Set ScopedIdent -> Expr -> r
  g' s (Literal _ l) = lit g'' s l
  g' s (UnaryMinus _ v1) = g'' s v1
  g' s (BinaryNoParens op v1 v2) = g'' s op <> g'' s v1 <> g'' s v2
  g' s (Parens v1) = g'' s v1
  g' s (Accessor _ v1) = g'' s v1
  g' s (ObjectUpdate obj vs) = g'' s obj <> foldMap (g'' s . snd) vs
  g' s (ObjectUpdateNested obj vs) = g'' s obj <> foldMap (g'' s) vs
  g' s (Abs b v1) =
    let s' = S.union (S.fromList (localBinderNames b)) s
    in h'' s b <> g'' s' v1
  g' s (App v1 v2) = g'' s v1 <> g'' s v2
  g' s (VisibleTypeApp v _) = g'' s v
  g' s (Unused v) = g'' s v
  g' s (IfThenElse v1 v2 v3) = g'' s v1 <> g'' s v2 <> g'' s v3
  g' s (Case vs alts) = foldMap (g'' s) vs <> foldMap (i'' s) alts
  g' s (TypedValue _ v1 _) = g'' s v1
  g' s (Let _ ds v1) =
    let s' = S.union s (S.fromList (map LocalIdent (mapMaybe getDeclIdent ds)))
    in foldMap (f'' s') ds <> g'' s' v1
  g' s (Do _ es) = fold . snd . mapAccumL j'' s $ es
  g' s (Ado _ es v1) =
    -- Every element is queried in the enclosing scope. Its bound names only
    -- reach the final expression, but its results must still be reported
    -- (previously they were computed and then discarded).
    let stepped = fmap (j'' s) es
        s' = S.union s (foldMap fst stepped)
    in foldMap snd stepped <> g'' s' v1
  g' s (PositionedValue _ _ v1) = g'' s v1
  g' _ _ = mempty

  h'' :: S.Set ScopedIdent -> Binder -> r
  h'' s a = h s a <> h' s a

  h' :: S.Set ScopedIdent -> Binder -> r
  h' s (LiteralBinder _ l) = lit h'' s l
  h' s (ConstructorBinder _ _ bs) = foldMap (h'' s) bs
  h' s (BinaryNoParensBinder b1 b2 b3) = foldMap (h'' s) [b1, b2, b3]
  h' s (ParensInBinder b) = h'' s b
  h' s (NamedBinder _ name b1) = h'' (S.insert (LocalIdent name) s) b1
  h' s (PositionedBinder _ _ b1) = h'' s b1
  h' s (TypedBinder _ b1) = h'' s b1
  h' _ _ = mempty

  lit :: (S.Set ScopedIdent -> a -> r) -> S.Set ScopedIdent -> Literal a -> r
  lit go s (ArrayLiteral as) = foldMap (go s) as
  lit go s (ObjectLiteral as) = foldMap (go s . snd) as
  lit _ _ _ = mempty

  i'' :: S.Set ScopedIdent -> CaseAlternative -> r
  i'' s a = i s a <> i' s a

  -- The alternative's binders are queried in the outer scope; the names they
  -- bind are visible in its guards and right-hand sides.
  i' :: S.Set ScopedIdent -> CaseAlternative -> r
  i' s (CaseAlternative bs gs) =
    let s' = S.union s (S.fromList (concatMap localBinderNames bs))
    in foldMap (h'' s) bs <> foldMap (l' s') gs

  -- Do-notation elements also return the scope that follows them.
  j'' :: S.Set ScopedIdent -> DoNotationElement -> (S.Set ScopedIdent, r)
  j'' s a = let (s', r) = j' s a in (s', j s a <> r)

  j' :: S.Set ScopedIdent -> DoNotationElement -> (S.Set ScopedIdent, r)
  j' s (DoNotationValue v) = (s, g'' s v)
  j' s (DoNotationBind b v) =
    let s' = S.union (S.fromList (localBinderNames b)) s
    in (s', h'' s b <> g'' s v)
  j' s (DoNotationLet ds) =
    let s' = S.union s (S.fromList (map LocalIdent (mapMaybe getDeclIdent ds)))
    in (s', foldMap (f'' s') ds)
  j' s (PositionedDoNotationElement _ _ e1) = j'' s e1

  -- Guards return the scope seen by whatever follows them.
  k' :: S.Set ScopedIdent -> Guard -> (S.Set ScopedIdent, r)
  k' s (ConditionGuard e) = (s, g'' s e)
  k' s (PatternGuard b e) =
    let s' = S.union (S.fromList (localBinderNames b)) s
    in (s', h'' s b <> g'' s' e)

  -- Query a guarded right-hand side, threading the scope through its guards
  -- and into the expression.
  l' :: S.Set ScopedIdent -> GuardedExpr -> r
  l' s (GuardedExpr [] e) = g'' s e
  l' s (GuardedExpr (grd:gs) e) =
    let (s', r) = k' s grd
    in r <> l' s' (GuardedExpr gs e)

  getDeclIdent :: Declaration -> Maybe Ident
  getDeclIdent (ValueDeclaration vd) = Just (valdeclIdent vd)
  getDeclIdent (TypeDeclaration td) = Just (tydeclIdent td)
  getDeclIdent _ = Nothing

  localBinderNames :: Binder -> [ScopedIdent]
  localBinderNames = map LocalIdent . binderNames

-- | Collect a monoidal summary of the 'SourceType's written in declarations,
-- expressions and binders by feeding each one to @f@. Case alternatives and
-- do-notation elements contribute nothing of their own; the types inside
-- them are still reached through their expressions and binders.
accumTypes
  :: (Monoid r)
  => (SourceType -> r)
  -> ( Declaration -> r
     , Expr -> r
     , Binder -> r
     , CaseAlternative -> r
     , DoNotationElement -> r
     )
accumTypes f =
  everythingOnValues mappend typesInDecl typesInExpr typesInBinder (const mempty) (const mempty)
  where
  -- Type arguments, constructor fields, signatures, constraints and kinds.
  typesInDecl decl = case decl of
    DataDeclaration _ _ _ args dctors ->
      foldMap (foldMap f . snd) args <> foldMap (foldMap (f . snd) . dataCtorFields) dctors
    ExternDataDeclaration _ _ ty -> f ty
    ExternDeclaration _ _ ty -> f ty
    TypeClassDeclaration _ _ args implies _ _ ->
      foldMap (foldMap (foldMap f)) args <> foldMap (foldMap f . constraintArgs) implies
    TypeInstanceDeclaration _ _ _ _ _ cs _ tys _ ->
      foldMap (foldMap f . constraintArgs) cs <> foldMap f tys
    TypeSynonymDeclaration _ _ args ty ->
      foldMap (foldMap f . snd) args <> f ty
    KindDeclaration _ _ _ ty -> f ty
    TypeDeclaration td -> f (tydeclType td)
    _ -> mempty

  -- Types carried by dictionaries, annotations and visible type applications.
  typesInExpr expr = case expr of
    TypeClassDictionary c _ _ -> foldMap f (constraintArgs c)
    DeferredDictionary _ tys -> foldMap f tys
    TypedValue _ _ ty -> f ty
    VisibleTypeApp _ ty -> f ty
    _ -> mempty

  -- Only annotated binders carry a type.
  typesInBinder binder = case binder of
    TypedBinder ty _ -> f ty
    _ -> mempty

-- | Map a function over the type annotations appearing inside a value.
--
-- The function is applied to:
--
-- * the type of every 'TypedValue';
-- * for every 'TypeClassDictionary', the constraint's arguments;
-- * the instance types of the dictionaries stored in that dictionary's scope
--   map under 'ByNullSourcePos'.
--
-- All other nodes are left unchanged.
overTypes :: (SourceType -> SourceType) -> Expr -> Expr
overTypes f = rewriteExpr
  where
  (_, rewriteExpr, _) = everywhereOnValues id onNode id

  onNode :: Expr -> Expr
  onNode expr = case expr of
    TypedValue checkTy val t -> TypedValue checkTy val (f t)
    TypeClassDictionary c sco hints ->
      TypeClassDictionary
        (mapConstraintArgs (fmap f) c)
        (M.alter overScope ByNullSourcePos sco)
        hints
    _ -> expr

  -- Reach through the nested containers of the scope entry down to each
  -- dictionary and rewrite its instance types.
  overScope = fmap . fmap . fmap . fmap $ \dict -> dict { tcdInstanceTypes = fmap f (tcdInstanceTypes dict) }
